
include("../src/qmcmary.jl")
using ..qmcmary

using Test
using LinearAlgebra


"""
计算G
"""
function rawgf(allbmats, idx)
    Nt = length(allbmats)
    Nx2 = size(allbmats[1])[1]
    Rs = I(Nx2)#allbmats[1]
    for i in Base.OneTo(idx)
        Rs = allbmats[i]*Rs
    end
    Ls = Diagonal(ones(Nx2))
    for i in Nt:-1:(idx+1)
        Ls = Ls*allbmats[i]
    end
    return inv(Diagonal(ones(Nx2)) + Rs*Ls)
end

"""
计算G
"""
function rawgft(allbmats, idx)
    # g(t,0) = <c(t) c+(0)> = B(t,0)G(0,0)
    # g(0,t) = -<c+(t) c(0)> = −(1−G(0,0))B^−1(t,0) 
    g0 = rawgf(allbmats, 0)
    if idx > 0
        bt0 = prod(reverse(allbmats[1:idx]))
    else
        bt0 = I(size(allbmats[end])[1])
    end
    gt0 = bt0*g0
    ind = I(size(allbmats[end])[1])
    g0t = -(ind-g0)*inv(bt0)
    return g0t, gt0
end


"""
计算G
"""
function rawgft2(allbmats, idx)
    # g(Nt, tau) = <c(Nt) c+(tau)> = B(Nt,tau)G(tau,tau)
    # g(tau, Nt) = -<c+(Nt) c(tau)> = −(1−G(tau,tau))B^−1(Nt,tau)
    Nt = length(allbmats)
    gtau = rawgf(allbmats, idx)
    if idx > 0
        bt0 = prod(reverse(allbmats[tau:Nt]))
    else
        bt0 = I(size(allbmats[end])[1])
    end
    gt0 = bt0*gtau
    ind = I(size(allbmats[end])[1])
    g0t = -(ind-gtau)*inv(bt0)
    return g0t, gt0
end


"""
    gf1()

Manual consistency check for `ScrollSVD`.

The check proceeds as follows:

- Builds three groups of five matrices from `fakedata(4)` and pushes them onto a
  `ScrollSVD`.
- Scrolls right-to-left (`scrollR2L`) and then left-to-right (`scrollL2R`) one group at
  a time.
- After each step, prints the L1 residual between the stored SVD factors
  (`U * Diagonal(S) * Vt`) or `eq_green_scratch(ss)` and the directly computed products
  and `inv(I + B⋯B)`.

Results are printed only; nothing is asserted.
"""
function gf1()
    # One group = five slice matrices (presumably 4x4, see Diagonal(ones(4)) below).
    newgroup() = [fakedata(4)[1] for _ in 1:5]
    bmats1 = newgroup()
    bmats2 = newgroup()
    bmats3 = newgroup()

    ss = ScrollSVD(bmats1)
    push!(ss, bmats2)
    push!(ss, bmats3)

    # Rebuild the matrix product stored in one SVD factor.
    rebuild(F) = F.U * Diagonal(F.S) * F.Vt
    # Print the L1 distance between a stored factor and its expected product.
    report_factor(F, expected) = println(sum(abs.(rebuild(F) - expected)))
    # Print the L1 distance between inv(I + chain) and the scratch Green's function.
    # NOTE(review): gf2 destructures `eq_green_scratch(ss)` into `(gf, ph)`; if it
    # returns a tuple for this `ss` as well, the subtraction here fails — confirm.
    report_green(chain) =
        println(sum(abs.(inv(Diagonal(ones(4)) + chain) - eq_green_scratch(ss))))

    println(ss.L)

    # Ordered product of each group: B[5] * ... * B[1].
    b1, b2, b3 = map(g -> prod(reverse(g)), (bmats1, bmats2, bmats3))

    # Freshly built stack: the last factor holds the whole product.
    report_factor(ss.F[end], b3*b2*b1)

    # --- scroll right-to-left, one group at a time ---
    scrollR2L(ss, bmats3)
    println(ss.L)
    report_factor(ss.F[2], b2*b1)
    report_factor(ss.F[3], b3)
    report_green(b2*b1*b3)

    scrollR2L(ss, bmats2)
    report_green(b1*b3*b2)

    scrollR2L(ss, bmats1)
    report_green(b3*b2*b1)

    println(ss.L)

    # --- scroll back left-to-right ---
    scrollL2R(ss, bmats1)
    report_green(b1*b3*b2)
    report_factor(ss.F[1], b1)
    report_factor(ss.F[2], b3*b2)

    scrollL2R(ss, bmats2)
    report_factor(ss.F[2], b2*b1)
    report_factor(ss.F[3], b3)
    report_green(b2*b1*b3)

    scrollL2R(ss, bmats1)
    report_green(b3*b2*b1)
    return nothing
end


"""
    gf2()

Consistency test for the Green's-function routines of `qmcmary`. It uses a random
Hermitian `hk` on `Nx = 6` sites, `Nt = 20` slices in groups of `Ng = 5`, and a random
±1 configuration `hscfg`.

1. Equal time: propagates `eq_green_scratch(ss)` through the last `Ng` slices with
   `down_prog_green`. It checks that the result (and the phase) agrees with a fresh
   `eq_green_scratch` after `scrollR2L` over the same slices.
2. Unequal time: after `scrollL2R`, it checks `ueq_green_func` against the brute-force
   `rawgft` for every slice `t` (τ = t - 1).

Failures are reported through `Test` testsets; the two diagnostic values are also printed.
"""
function gf2()
    Nx = 6
    hk = rand(ComplexF64, Nx, Nx)
    hk += adjoint(hk)  # make hk Hermitian
    Ui = ones(Nx)
    theta = rand(Nx) * pi
    nx = zeros(Nx)
    ny = sin.(theta)
    nz = cos.(theta)

    Nt = 20
    Ng = 5
    # One row of ±1 entries per time slice.
    hscfg = rand([-1, 1], Nt, Nx)
    sp = default_splitting(Nt, hk, Ui)
    sslen, bmats, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)

    # Equal time: step-by-step propagation vs. recomputation after scrolling.
    gf, ph = eq_green_scratch(ss)
    for i in Base.OneTo(Ng)
        gf = down_prog_green(gf, allbmats[Nt-i+1])
    end
    scrollR2L(ss, allbmats[Nt-Ng+1:Nt])
    gf_scroll, ph_scroll = eq_green_scratch(ss)
    println(maximum(abs.(gf - gf_scroll)))
    println(ph, " ", ph_scroll)
    @testset "eq gf" begin
        @test all(isapprox.(gf, gf_scroll, atol=1e-6))
        @test all(isapprox.(ph, ph_scroll, atol=1e-6))
    end

    # Unequal time: compare against the brute-force reference at every slice.
    scrollL2R(ss, allbmats[Nt-Ng+1:Nt])
    Gf0t, Gft0 = ueq_green_func(Nt, Ng, Nx, sslen, bmats, allbmats, ss)
    @testset "ueq gf" begin
        for t in Base.OneTo(Nt)
            tau = t - 1
            g0t, gt0 = rawgft(allbmats, tau)
            @test all(isapprox.(Gft0[t, :, :], gt0, atol=1e-8))
            @test all(isapprox.(Gf0t[t, :, :], g0t, atol=1e-8))
        end
    end
    return nothing
end


# Run the gf2 Green's-function consistency check when this file is loaded.
# gf1 is defined above but is not invoked here.
gf2()
